{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Covariate-Assisted Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In many network problems, our network might be more than just the information contained in its adjacency matrix - we might have extra information in the form of a set of covariates for each node. Covariate-Assisted Embedding (CASE) uses both these covariates and our adjacency matrix to create and embed a new representation of our network.\n",
    "\n",
    "There are two primary reasons that we might want to explore using node covariates in addition to topological structure. First, they might improve our embedding if the latent structure of our covariates lines up with the latent structure of our adjacency matrix. Second, figuring out what the clusters of an embedding actually mean can sometimes be difficult - and covariates can create a natural structure.\n",
    "\n",
    "To illustrate CASE, we'll use a model in which some of our community information is in the covariates and some is in our adjacency matrix. we’ll generate an SBM with three communities, with the first and second communities indistinguishable, then a set of covariates, with the second and third communities indistinguishable. Using CASE, we can find an embedding that lets us find distinct community structure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SBM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we sample a 3-block SBM with 1500 nodes, 500 nodes per community. We'll use the following block probability matrix:\n",
    "\n",
    "\\begin{align*}\n",
    "B = \n",
    "\\begin{bmatrix}\n",
    "0.3 & 0.3 & .15 \\\\\n",
    "0.3 & 0.3 & .15 \\\\\n",
    ".15 & .15 & .3\n",
    "\\end{bmatrix}~\n",
    "\\end{align*}\n",
    "\n",
    "Because $B$ has the same probability values in its upper-left $2 \\times 2$ square, we'll see the nodes in communities one and two as a single giant block in our adjacency matrix. Nodes in community 3 are distinct. So, the end result is that we have three communities that we'd like to separate into distinct clusters, but the topological structure in the adjacency matrix can't separate the three groups by itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(42)\n",
    "\n",
    "import graspologic\n",
    "from graspologic.simulations import sbm\n",
    "\n",
    "# Start with some simple parameters\n",
    "N = 1500  # Total number of nodes\n",
    "n = N // 3  # Nodes per community\n",
    "p, q = .3, .15\n",
    "B = np.array([[p, p, q],\n",
    "              [p, p, q],\n",
    "              [q, q, p]])\n",
    "\n",
    "# Sample from SBM\n",
    "A, labels = sbm([n, n, n], B, return_labels = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here you can see what our adjacency matrix looks like. Notice the giant block in the top-left: this block contains both nodes in both of the first two communities, and they're indistinguishable from each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from graspologic.plot import heatmap\n",
    "import seaborn as sns\n",
    "import matplotlib\n",
    "\n",
    "# visualize\n",
    "fig, ax = plt.subplots(figsize=(10,10))\n",
    "\n",
    "def plot_heatmap(A, ax, title=\"3-block SBM (first two blocks indistinguishable)\", show_cbar=True):\n",
    "    cmap = matplotlib.colors.ListedColormap([\"white\", \"black\"])\n",
    "    ax = heatmap(A, cmap=cmap, ax=ax, inner_hier_labels=labels, title=title, center=None)\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    if show_cbar:\n",
    "        cbar.set_ticks([0.25, .75])\n",
    "        cbar.set_ticklabels(['No Edge', 'Edge'])\n",
    "        cbar.ax.set_frame_on(True)\n",
    "    else:\n",
    "        cbar.remove()\n",
    "            \n",
    "        \n",
    "plot_heatmap(A, ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we wanted to embed this graph using LSE or ASE, we'd find the first and second communities layered on top of each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "from graspologic.embed import LaplacianSpectralEmbed as LSE\n",
    "from graspologic.utils import to_laplacian\n",
    "from graspologic.plot import pairplot\n",
    "\n",
    "L = to_laplacian(A, form=\"R-DAD\")\n",
    "L_latents = LSE(n_components=2).fit_transform(L)\n",
    "\n",
    "def plot_latents(latent_positions, *, title, labels, ax=None):\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    plot = sns.scatterplot(x=latent_positions[:, 0], y=latent_positions[:, 1], hue=labels, \n",
    "                           linewidth=0, s=10, ax=ax, palette=\"Set1\")\n",
    "    plot.set_title(title, wrap=True);\n",
    "    ax.axes.xaxis.set_visible(False)\n",
    "    ax.axes.yaxis.set_visible(False)\n",
    "    ax.legend(loc=\"upper right\", title=\"Community\")\n",
    "    \n",
    "    return plot\n",
    "\n",
    "plot = plot_latents(L_latents, title=\"Latent positions from LSE\", \n",
    "             labels=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, we'd have a tough time clustering this. It would be nice if we could use extra information to more clearly distinguish between communities 0 and 1. We don't have this information in our adjacency matrix: it needs to come from our covariates."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "margin"
    ]
   },
   "source": [
    "### Covariates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we sample a matrix of covariates $Y$. Each node is associated with its own group of 30 covariates. The $i_{th}$ row of $Y$ contains the covariates associated with node $i$.\n",
    "\n",
    "Covariates in community 1 will be drawn from a $Beta(2,5)$ distribution, whereas covariates in communities 2 or 3 will be drawn from a $Beta(2,2)$ distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.stats import beta\n",
    "\n",
    "def make_community(a, b, n=500):\n",
    "    return beta.rvs(a, b, size=(n, 30))\n",
    "\n",
    "def gen_covariates(n=500):\n",
    "    c1 = make_community(2, 5, n=n)\n",
    "    c2 = make_community(2, 2, n=n)\n",
    "    c3 = make_community(2, 2, n=n)\n",
    "\n",
    "    covariates = np.vstack((c1, c2, c3))\n",
    "    return covariates\n",
    "    \n",
    "\n",
    "# Generate a covariate matrix\n",
    "Y = gen_covariates()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a visualization of the covariates we just created. \n",
    "\n",
    "On the left is the covariates themselves. The first community is represented by the 500 lighter-colored rows, and the last two are represented by the 1000 darker-colored rows. On the right is a function of the covariates, $\\alpha YY^\\top$, which is used in CASE for embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "# Generate grid spec\n",
    "fig = plt.figure(tight_layout=True, figsize=(10, 6.5))\n",
    "gs = GridSpec(5, 6)\n",
    "fig.add_subplot(gs[:, 0:2])\n",
    "fig.add_subplot(gs[:, 2:])\n",
    "axs = fig.axes\n",
    "\n",
    "# Plot heatmaps\n",
    "Y_ax = sns.heatmap(Y, ax=axs[0], cmap=\"rocket_r\", cbar=False, yticklabels=500)\n",
    "Y_ax.set(title=r\"$Y$\", xticks=[], \n",
    "       ylabel=\"Nodes\",\n",
    "       xlabel=\"Covariates\");\n",
    "\n",
    "\n",
    "YYt = Y@Y.T\n",
    "aYYt = heatmap(YYt, title=r\"$YY^\\top$\", ax=axs[1], cbar=False);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we embed the information contained in this matrix of covariates, we can see the reverse situation as before - the first community is separate, but the last two are overlayed on top of each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils.extmath import randomized_svd\n",
    "\n",
    "def embed(matrix, *, dimension):\n",
    "    latents, _, _ = randomized_svd(matrix, n_components=dimension)\n",
    "    return latents\n",
    "\n",
    "Y_latents = embed(Y, dimension=2)\n",
    "\n",
    "plot_latents(Y_latents, title=\"Embedding from covariates\", \n",
    "             labels=labels);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of the first and second communities being indistinguishable, the second and third now are. We'd like to see full separation between all three communities, so we need some kind of representation of our network that allows us to use both the information in the adjacency matrix and the information in the covariates."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CASE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<i>Covariate-Assisted Spectral Embedding</i> is a simple way of combining our network and our covariates into a single model. In the most straightforward version of CASE, we sum the network's regularized Laplacian matrix $L$ and a function of our covariate matrix $YY^\\top$. Here, $Y$ is just our covariate matrix, in which row $i$ contains the covariates associated with node $i$.\n",
    "\n",
    "$$\n",
    "L + \\alpha YY^\\top\n",
    "$$\n",
    "\n",
    "$\\alpha$ is multiplied by $YY^\\top$ so that both matrices contribute an equal amount of useful information to the embedding. $\\alpha$ defaults to the ratio of the largest eigenvalues of $L$ and $YY^\\top$:\n",
    "\n",
    "$$\n",
    "\\alpha = \\frac{\\lambda_1 (L)}{\\lambda_1 (YY^\\top)}\n",
    "$$\n",
    "\n",
    "This is the weight that causes $L$ and $YY^\\top$ to contribute the same amount of information in their leading eigenspaces. \n",
    "\n",
    "Below, you can see graspologic's embedding of the weighted sum of the above two matrices. As you can see, the 3-community structure has been recovered by using both covariate information and topological information. Note that CASE accepts an adjacency matrix as its input into its `fit` method, not a Laplacian matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.embed import CovariateAssistedEmbed as CASE\n",
    "\n",
    "case = CASE(assortative=True, n_components=2)\n",
    "latents = case.fit_transform(graph=A, covariates=Y)\n",
    "plot_latents(latents, title=r\"CASE embedding\", labels=labels);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exploring Possible Weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than simply letting the weight $\\alpha$ be set automatically, you can also set a custom value. Below is a comparison of potential $\\alpha$ values in our experimental setup. There are 9 possible $\\alpha$ values, ranging between $10^{-5}$ and 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))\n",
    "for a, ax in zip(np.geomspace(10e-5, 100, num=9), axs.flat):\n",
    "    case = CASE(alpha=a, assortative=True, n_components=2)\n",
    "    latents = case.fit_transform(graph=A, covariates=Y)\n",
    "    plot_latents(latents, title=f\"weight: {a:.3f}\", labels=labels, ax=ax)\n",
    "    ax.get_legend().remove()\n",
    "\n",
    "fig.suptitle(r\"Comparison of embeddings for differents $\\alpha$ values on $XX^\\top$\", \n",
    "             y=1, fontsize=25);\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we were to set the weight manually, it looks like we'd want a weight around 0.5. In practice, if the default weight is producing undesirable results, a shrewd data scientist could find a good weight by clustering with k-means or a GMM, then performing a line search (for example, `minimize_scalar` in scipy's [optimize](https://docs.scipy.org/doc/scipy/reference/optimize.html) module) on some metric which optimizes cluster separation, like BIC or sklearn's [silhouette score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Non-Assortative Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your graph is **non-assortative** - meaning, the between-block probabilities are greater than the within-block probabilities - it's better to square our Laplacian. This gets rid of a lot of annoying negative eigenvalues, and we end up embedding $LL + aYY^\\top$. Below, you can see the embedding in the non-assortative case. In practice, if you don't know whether your graph is assortative or non-assortative, you can try both algorithms and use whichever one works best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate a non-assortative adjacency matrix\n",
    "p, q = .15, .3\n",
    "B = np.array([[p, p, q],\n",
    "              [p, p, q],\n",
    "              [q, q, p]])\n",
    "A, labels = sbm([n, n, n], B, return_labels = True)\n",
    "\n",
    "# embed and plot\n",
    "case = CASE(assortative=False, n_components=2)\n",
    "latents = case.fit_transform(graph=A, covariates=Y)\n",
    "plot_latents(latents, title=\"Embedding in the non-assortative case\", labels=labels);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "toc-autonumbering": false,
  "toc-showcode": false,
  "toc-showmarkdowntxt": false,
  "toc-showtags": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
